{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LightHuBERT\n",
    "\n",
    "The LightHuBERT tutorial support features as follows.\n",
    "\n",
    "1. Load different checkpoints to the LightHuBERT architecture, such as base supernet, small supernet, and stage 1.\n",
    "2. Sample a subnet and set the subnet.\n",
    "3. Conduct inference of the subnet with a random input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "from lighthubert import LightHuBERT, LightHuBERTConfig\n",
    "\n",
    "device=\"cuda:2\"\n",
    "wav_input_16khz = torch.randn(1,10000).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LightHuBERT Base Supernet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]\n",
      "2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | search space (6,530,347,008 subnets): {'atten_dim': [512, 640, 768], 'embed_dim': [512, 640, 768], 'ffn_ratio': [3.5, 4.0], 'heads_num': [8, 10, 12], 'layer_num': [12]}\n",
      "2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | min subnet (41 Params): {'atten_dim': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792], 'heads_num': [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | max subnet (94 Params): {'atten_dim': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'embed_dim': 768, 'ffn_embed': [3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072], 'heads_num': [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])\n",
      "subnet (Params 85M) | {'atten_dim': [512, 640, 512, 640, 768, 512, 640, 512, 512, 640, 768, 640], 'embed_dim': 768, 'ffn_embed': [3072, 2688, 3072, 2688, 2688, 3072, 2688, 3072, 3072, 2688, 3072, 2688], 'heads_num': [8, 10, 8, 10, 12, 8, 10, 8, 8, 10, 12, 10], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "Representation at bottom hidden states: True\n"
     ]
    }
   ],
   "source": [
    "# checkpoint = torch.load('/path/to/lighthubert_base.pt')\n",
    "checkpoint = torch.load('/workspace/projects/lighthubert/checkpoints/lighthubert_base.pt')\n",
    "cfg = LightHuBERTConfig(checkpoint['cfg']['model'])\n",
    "cfg.supernet_type = 'base'\n",
    "model = LightHuBERT(cfg)\n",
    "model = model.to(device)\n",
    "model = model.eval()\n",
    "print(model.load_state_dict(checkpoint['model'], strict=False))\n",
    "\n",
    "# (optional) set a subnet\n",
    "subnet = model.supernet.sample_subnet()\n",
    "model.set_sample_config(subnet)\n",
    "params = model.calc_sampled_param_num()\n",
    "print(f\"subnet (Params {params / 1e6:.0f}M) | {subnet}\")\n",
    "\n",
    "# extract the the representation of last layer\n",
    "rep = model.extract_features(wav_input_16khz)[0]\n",
    "\n",
    "# extract the the representation of each layer\n",
    "hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]\n",
    "\n",
    "print(f\"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LightHuBERT Small Supernet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-03-29 20:53:04 | INFO | lighthubert.lighthubert | LightHuBERT Config: <lighthubert.lighthubert.LightHuBERTConfig object at 0x7f078337e430>\n",
      "2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]\n",
      "2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | search space (951,892,141,473 subnets): {'atten_dim': [256, 384, 512], 'embed_dim': [256, 384, 512], 'ffn_ratio': [3.0, 3.5, 4.0], 'heads_num': [4, 6, 8], 'layer_num': [10, 11, 12]}\n",
      "2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | min subnet (11 Params): {'atten_dim': [256, 256, 256, 256, 256, 256, 256, 256, 256, 256], 'embed_dim': 256, 'ffn_embed': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'heads_num': [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 'layer_num': 10, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | max subnet (44 Params): {'atten_dim': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512], 'embed_dim': 512, 'ffn_embed': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'heads_num': [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])\n",
      "subnet (Params 38M) | {'atten_dim': [256, 512, 256, 256, 384, 384, 256, 512, 256, 512, 384, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1536, 1536, 2048, 1792, 1536, 1536, 2048, 1792, 2048, 2048, 2048], 'heads_num': [4, 8, 4, 4, 6, 6, 4, 8, 4, 8, 6, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "Representation at bottom hidden states: True\n"
     ]
    }
   ],
   "source": [
    "# checkpoint = torch.load('/path/to/lighthubert_small.pt')\n",
    "checkpoint = torch.load('/workspace/projects/lighthubert/checkpoints/lighthubert_small.pt')\n",
    "cfg = LightHuBERTConfig(checkpoint['cfg']['model'])\n",
    "cfg.supernet_type = 'small'\n",
    "model = LightHuBERT(cfg)\n",
    "model = model.to(device)\n",
    "model = model.eval()\n",
    "print(model.load_state_dict(checkpoint['model'], strict=False))\n",
    "\n",
    "# (optional) set a subnet\n",
    "subnet = model.supernet.sample_subnet()\n",
    "model.set_sample_config(subnet)\n",
    "params = model.calc_sampled_param_num()\n",
    "print(f\"subnet (Params {params / 1e6:.0f}M) | {subnet}\")\n",
    "\n",
    "# extract the the representation of last layer\n",
    "rep = model.extract_features(wav_input_16khz)[0]\n",
    "\n",
    "# extract the the representation of each layer\n",
    "hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]\n",
    "\n",
    "print(f\"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LightHuBERT Stage 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-03-29 20:53:20 | INFO | lighthubert.lighthubert | LightHuBERT Config: <lighthubert.lighthubert.LightHuBERTConfig object at 0x7f097862fe80>\n",
      "2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]\n",
      "2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | search space (6,530,347,008 subnets): {'atten_dim': [512, 640, 768], 'embed_dim': [512, 640, 768], 'ffn_ratio': [3.5, 4.0], 'heads_num': [8, 10, 12], 'layer_num': [12]}\n",
      "2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | min subnet (41 Params): {'atten_dim': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792], 'heads_num': [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | max subnet (94 Params): {'atten_dim': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'embed_dim': 768, 'ffn_embed': [3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072], 'heads_num': [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])\n",
      "subnet (Params 94M) | {'atten_dim': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'embed_dim': 768, 'ffn_embed': [3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072], 'heads_num': [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}\n",
      "Representation at bottom hidden states: True\n"
     ]
    }
   ],
   "source": [
    "# checkpoint = torch.load('/path/to/lighthubert_stage1.pt')\n",
    "checkpoint = torch.load('/workspace/projects/lighthubert/checkpoints/lighthubert_stage1.pt')\n",
    "cfg = LightHuBERTConfig(checkpoint['cfg']['model'])\n",
    "cfg.supernet_type = 'base'\n",
    "model = LightHuBERT(cfg)\n",
    "model = model.to(device)\n",
    "model = model.eval()\n",
    "print(model.load_state_dict(checkpoint['model'], strict=False))\n",
    "\n",
    "# (optional) set a subnet\n",
    "subnet = model.supernet.max_subnet\n",
    "model.set_sample_config(subnet)\n",
    "params = model.calc_sampled_param_num()\n",
    "print(f\"subnet (Params {params / 1e6:.0f}M) | {subnet}\")\n",
    "\n",
    "# extract the the representation of last layer\n",
    "rep = model.extract_features(wav_input_16khz)[0]\n",
    "\n",
    "# extract the the representation of each layer\n",
    "hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]\n",
    "\n",
    "print(f\"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a07f300b2e0ac9383d668273c6d99cb9a788a1be8923260608fa0c116eacfb4c"
  },
  "kernelspec": {
   "display_name": "Python 3.8.0 ('speech_transducers')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
